/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/hlo/ir/mesh_and_axis.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "xla/array.h"
#include "xla/hlo/ir/tile_assignment.h"
#include "xla/tsl/platform/errors.h"
#include "xla/xla_data.pb.h"

namespace xla {

// Checks the invariants of a mesh built from a device assignment and a list of
// axis names. Returns InvalidArgumentError when:
//   * the device assignment has no dimensions or there are no axis names,
//   * the number of axis names differs from the number of dimensions,
//   * an axis name appears more than once,
//   * (non-iota assignments only) a device id is negative, the id list is
//     exactly [0,1,2,...], or the ids are not a permutation of [0,1,2,...].
// Returns OkStatus otherwise.
absl::Status Mesh::ValidateMesh() {
  const auto num_dims = device_assignment_.dimensions().size();

  // TODO(varcho): An empty mesh is valid in Shardy. If support for such meshes
  // is required, update this validation.
  if (num_dims == 0 || axes_names_.empty()) {
    return absl::InvalidArgumentError("Mesh must have at least one axis.");
  }
  if (num_dims != axes_names_.size()) {
    return absl::InvalidArgumentError(
        "Number of axes names must match number of dimensions in the device "
        "assignment.");
  }

  // A set built from the names is smaller than the list iff a name repeats.
  const absl::flat_hash_set<std::string> unique_names(axes_names_.begin(),
                                                      axes_names_.end());
  if (unique_names.size() != axes_names_.size()) {
    return absl::InvalidArgumentError("Mesh has duplicate axis names.");
  }

  // Iota assignments need no further checks on the device ids.
  if (device_assignment_.iota().has_value()) {
    return absl::OkStatus();
  }

  const auto& assigned = device_assignment_.array();
  std::vector<int64_t> ids(assigned.begin(), assigned.end());
  if (absl::c_any_of(ids, [](int64_t id) { return id < 0; })) {
    return absl::InvalidArgumentError("Mesh device ids must be non-negative.");
  }

  std::vector<int64_t> identity(ids.size());
  std::iota(identity.begin(), identity.end(), 0);

  // An explicit id list that is the identity should have been expressed as an
  // iota assignment instead.
  if (ids == identity) {
    return absl::InvalidArgumentError(
        "Non-iota device assignment has iota device id list [0,1,2,3...].");
  }

  // After sorting, a valid permutation must equal the identity list.
  absl::c_sort(ids);
  if (ids != identity) {
    return absl::InvalidArgumentError(
        "Device ids must be a permutation of [0,1,2,3...].");
  }
  return absl::OkStatus();
}

// Builds a mesh from a device assignment and one name per axis. The axis names
// are copied into `axes_names_`. The process is terminated via CHECK_OK if
// ValidateMesh() reports an error.
Mesh::Mesh(TileAssignment device_assignment,
           absl::Span<const absl::string_view> axes_names)
    : device_assignment_(std::move(device_assignment)),
      axes_names_(axes_names.begin(), axes_names.end()) {
  CHECK_OK(ValidateMesh());
}

// Serializes this mesh into a MeshProto.
//   * No axes and no devices: returns an empty proto.
//   * No axes but at least one device (maximal mesh): the proto holds only the
//     first device id.
//   * Otherwise: one MeshAxis (name, size) per axis. The explicit device id
//     list is omitted only when the assignment is an iota whose reshape_dims
//     has exactly one entry; in every other case all device ids are written.
MeshProto Mesh::ToProto() const {
  MeshProto proto;

  if (axes_names_.empty()) {
    if (device_assignment_.num_elements() != 0) {
      // Maximal mesh
      // TODO(b/454008727): Validate device_ids_size is 1.
      proto.add_device_ids(*device_assignment_.array().begin());
    }
    return proto;
  }

  // Axis names and dimension sizes are paired positionally.
  for (auto [name, size] :
       llvm::zip_equal(axes_names_, device_assignment_.dimensions())) {
    MeshProto::MeshAxis* axis = proto.add_axes();
    axis->set_name(name);
    axis->set_size(size);
  }

  const auto& iota = device_assignment_.iota();
  const bool single_dim_iota =
      iota.has_value() && iota->reshape_dims().size() == 1;
  if (single_dim_iota) {
    // The device order is implied by the axis sizes; no ids are stored.
    return proto;
  }

  const auto& ids = device_assignment_.array();
  proto.mutable_device_ids()->Assign(ids.begin(), ids.end());
  return proto;
}

// Reconstructs a Mesh from its proto form.
//   * No axes and no device ids: returns a default-constructed Mesh.
//   * No axes but device ids present: returns a maximal mesh built from the
//     first device id.
//   * Axes without device ids: the device assignment is an iota tiling over
//     the axis sizes.
//   * Axes with device ids: the ids are laid out in an array shaped by the
//     axis sizes. The number of ids must equal the number of array elements;
//     a mismatch is a CHECK failure (consistent with the Mesh constructor,
//     which CHECK-fails on invalid input).
Mesh Mesh::FromProto(const MeshProto& proto) {
  // TODO(b/454008727): Add validators for Mesh and AxisRef FromProto methods.
  if (proto.axes_size() == 0) {
    if (proto.device_ids_size() == 0) {
      return Mesh();
    }
    // Maximal mesh
    // TODO(b/454008727): Validate device_ids_size is 1.
    return Mesh(proto.device_ids(0));
  }

  std::vector<int64_t> mesh_axis_sizes;
  // The views below point into `proto`, which outlives this function call.
  std::vector<absl::string_view> mesh_axis_names;
  mesh_axis_sizes.reserve(proto.axes_size());
  mesh_axis_names.reserve(proto.axes_size());
  for (const auto& axis : proto.axes()) {
    mesh_axis_sizes.push_back(axis.size());
    mesh_axis_names.push_back(axis.name());
  }
  absl::Span<const absl::string_view> mesh_axis_names_span =
      absl::MakeSpan(mesh_axis_names);

  // If device ids are not specified, create a mesh with iota tiling.
  if (proto.device_ids_size() == 0) {
    TileAssignment device_assignment =
        TileAssignment(IotaTileAssignment::Create(mesh_axis_sizes));
    return Mesh(std::move(device_assignment), mesh_axis_names_span);
  }

  // Otherwise, create a mesh with the specific device id ordering.
  Array<int64_t> device_ids_array(mesh_axis_sizes);
  // Guard the copy below: writing more ids than the array holds would run
  // past the end of its storage before any mesh validation takes place.
  CHECK_EQ(proto.device_ids_size(), device_ids_array.num_elements())
      << "Number of device ids must match the product of the mesh axis sizes.";
  absl::c_copy(proto.device_ids(), device_ids_array.begin());

  TileAssignment tile_assignment = TileAssignment(
      std::make_shared<Array<int64_t>>(std::move(device_ids_array)));
  return Mesh(std::move(tile_assignment), mesh_axis_names_span);
}

// Checks that this axis reference is well-formed with respect to `mesh`.
// Returns InvalidArgumentError when:
//   * the axis index is negative or not smaller than the number of mesh axes,
//   * (sub-axes only) pre-size or size is not positive,
//   * (sub-axes only) pre-size or size does not divide the full axis size,
//   * (sub-axes only) size is not strictly smaller than the full axis size,
//   * (sub-axes only) pre-size * size does not divide the full axis size.
// Returns OkStatus otherwise. Full-axis references only get the index check.
absl::Status AxisRef::Validate(const Mesh& mesh) const {
  const int64_t num_axes = static_cast<int64_t>(mesh.axis_names().size());
  if (mesh_axis_index_ < 0 || mesh_axis_index_ >= num_axes) {
    return absl::InvalidArgumentError(
        "Axis index must be less than number of axes.");
  }
  if (!sub_axis_info_.has_value()) {
    return absl::OkStatus();
  }

  const int64_t pre_size = sub_axis_info_->pre_size;
  const int64_t size = sub_axis_info_->size;
  // Sub-axis info assigned by FromProto does not go through the constructor
  // CHECKs, so reject non-positive values here before they are used as
  // divisors.
  if (pre_size <= 0 || size <= 0) {
    return absl::InvalidArgumentError(
        "Sub-axis pre-size and size must be positive.");
  }

  const int64_t axis_size = mesh.axis_size(mesh_axis_index_);
  if (axis_size % pre_size != 0 || axis_size % size != 0) {
    return absl::InvalidArgumentError(
        "Pre-size and size must divide the full axis size.");
  }
  if (size >= axis_size) {
    return absl::InvalidArgumentError(
        "Sub-axis size must be strictly less than the full axis size.");
  }
  // pre_size and size dividing the axis individually is not sufficient (e.g.
  // axis size 4 with pre_size 4 and size 2). Since pre_size divides axis_size
  // at this point, checking (axis_size / pre_size) % size is equivalent to
  // checking that pre_size * size divides axis_size, without forming a product
  // that could overflow.
  if ((axis_size / pre_size) % size != 0) {
    return absl::InvalidArgumentError(
        "Sub-axis pre-size * size must divide the full axis size.");
  }
  return absl::OkStatus();
}

// Serializes this axis reference. The mesh axis index is always written; the
// sub_axis_info message is populated only when this reference is a sub-axis.
AxisRefProto AxisRef::ToProto() const {
  AxisRefProto proto;
  proto.set_mesh_axis_index(mesh_axis_index_);
  if (!sub_axis_info_.has_value()) {
    return proto;
  }
  auto* sub_axis_proto = proto.mutable_sub_axis_info();
  sub_axis_proto->set_pre_size(sub_axis_info_->pre_size);
  sub_axis_proto->set_size(sub_axis_info_->size);
  return proto;
}

// Rebuilds an AxisRef from its proto form. The sub-axis info is copied only
// when the proto carries a sub_axis_info message. The sub-axis fields are
// assigned directly to the member, so the CHECKs in the two-argument
// constructor are not run on this path.
AxisRef AxisRef::FromProto(const AxisRefProto& proto) {
  AxisRef result(proto.mesh_axis_index());
  if (!proto.has_sub_axis_info()) {
    return result;
  }
  const auto& sub_axis_proto = proto.sub_axis_info();
  result.sub_axis_info_ = {sub_axis_proto.pre_size(), sub_axis_proto.size()};
  return result;
}

// Creates a reference to the mesh axis at `mesh_axis_index`; no sub-axis info
// is set by this constructor.
AxisRef::AxisRef(int64_t mesh_axis_index) : mesh_axis_index_(mesh_axis_index) {}

// Creates a reference to a sub-axis of the mesh axis at `mesh_axis_index`.
// CHECK-fails unless `sub_axis_info.pre_size` is greater than 0 and
// `sub_axis_info.size` is greater than 1.
AxisRef::AxisRef(int64_t mesh_axis_index, SubAxis sub_axis_info)
    : mesh_axis_index_(mesh_axis_index), sub_axis_info_(sub_axis_info) {
  CHECK_GT(sub_axis_info_->pre_size, 0) << "sub-axis pre-size must be >= 1";
  CHECK_GT(sub_axis_info_->size, 1) << "sub-axis size must be > 1";
}

// Decides whether two sub-axes of the same axis can coexist, given the smaller
// and larger of their pre-sizes and the smaller and larger of their next
// pre-sizes.
//   * If minNextPreSize <= maxPreSize the sub-axes do not overlap, and they
//     can coexist iff minNextPreSize divides maxPreSize.
//   * Otherwise they overlap, and they can coexist iff minPreSize divides
//     maxPreSize, maxPreSize divides minNextPreSize, and minNextPreSize divides
//     maxNextPreSize.
bool canSubAxesCoexist(int64_t minPreSize, int64_t maxPreSize,
                       int64_t minNextPreSize, int64_t maxNextPreSize) {
  const bool overlapping = minNextPreSize > maxPreSize;
  if (!overlapping) {
    // Only the gap between the two sub-axes needs to be valid.
    return maxPreSize % minNextPreSize == 0;
  }
  // Each boundary must be a multiple of the previous one.
  if (minNextPreSize % maxPreSize != 0) {
    return false;
  }
  if (maxPreSize % minPreSize != 0) {
    return false;
  }
  return maxNextPreSize % minNextPreSize == 0;
}

// Returns true if this reference and `other` can coexist. References to
// different mesh axes always can, as can any pair in which at least one side
// has no sub-axis info. Two sub-axes of the same axis are delegated to
// canSubAxesCoexist with their ordered pre-sizes and next pre-sizes.
bool AxisRef::CanCoexist(const AxisRef& other) const {
  const bool same_axis = mesh_axis_index() == other.mesh_axis_index();
  const bool both_sub_axes =
      sub_axis_info_.has_value() && other.sub_axis_info_.has_value();
  if (!same_axis || !both_sub_axes) {
    return true;
  }

  const int64_t lhs_pre = sub_axis_info_->pre_size;
  const int64_t rhs_pre = other.sub_axis_info_->pre_size;
  const int64_t lhs_next = sub_axis_info_->next_pre_size();
  const int64_t rhs_next = other.sub_axis_info_->next_pre_size();

  return canSubAxesCoexist(
      /*minPreSize=*/std::min(lhs_pre, rhs_pre),
      /*maxPreSize=*/std::max(lhs_pre, rhs_pre),
      /*minNextPreSize=*/std::min(lhs_next, rhs_next),
      /*maxNextPreSize=*/std::max(lhs_next, rhs_next));
}

// Returns true if this reference and `other` overlap. References to different
// mesh axes never overlap. On the same axis, a pair where at least one side
// has no sub-axis info always overlaps; two sub-axes overlap iff each one's
// pre-size is strictly below the other's next pre-size.
bool AxisRef::Overlaps(const AxisRef& other) const {
  const bool same_axis = mesh_axis_index() == other.mesh_axis_index();
  if (!same_axis) {
    return false;
  }

  const bool both_sub_axes =
      sub_axis_info_.has_value() && other.sub_axis_info_.has_value();
  if (!both_sub_axes) {
    // A full axis overlaps anything on the same axis.
    return true;
  }

  const int64_t lhs_begin = sub_axis_info_->pre_size;
  const int64_t lhs_end = sub_axis_info_->next_pre_size();
  const int64_t rhs_begin = other.sub_axis_info_->pre_size;
  const int64_t rhs_end = other.sub_axis_info_->next_pre_size();
  return lhs_begin < rhs_end && rhs_begin < lhs_end;
}

bool AxisRef::CanCoexistWithoutOverlap(const AxisRef& other) const {
  // Check if the axes are on different mesh dimensions. If so, they can always
  // coexist and never overlap.
  if (mesh_axis_index() != other.mesh_axis_index()) {
    return true;
  }

  // If one AxisRef is a full axis it will always overlap the other axis on the
  // same dimension.
  if (!sub_axis_info_.has_value() || !other.sub_axis_info_.has_value()) {
    return false;
  }

  const SubAxis& this_sub_axis = sub_axis_info_.value();
  const SubAxis& other_sub_axis = other.sub_axis_info_.value();

  int64_t this_pre_size = this_sub_axis.pre_size;
  int64_t other_pre_size = other_sub_axis.pre_size;
  int64_t this_next_pre_size = this_sub_axis.next_pre_size();
  int64_t other_next_pre_size = other_sub_axis.next_pre_size();

  // Check for overlapping sub-axes
  bool overlaps = (this_next_pre_size > other_pre_size) &&
                  (other_next_pre_size > this_pre_size);
  if (overlaps) {
    return false;
  }
  // Assert that sub-axes can coexist.
  auto [min_pre_size, max_pre_size] =
      std::minmax(this_pre_size, other_pre_size);
  auto [min_next_pre_size, max_next_pre_size] =
      std::minmax(this_next_pre_size, other_next_pre_size);

  // Sub-axes don't overlap, check if the gap is valid.
  return max_pre_size % min_next_pre_size == 0;
}

// Returns true if every pair of distinct entries in `axes` satisfies
// AxisRef::CanCoexistWithoutOverlap. An empty or single-element span has no
// pairs and therefore yields true.
bool AxesCanCoexistWithoutOverlap(absl::Span<const AxisRef> axes) {
  // Use a signed count: `axes.size() - 1` is unsigned and wraps around to a
  // huge value when the span is empty, which would keep the outer loop
  // running instead of returning immediately.
  const int64_t num_axes = static_cast<int64_t>(axes.size());
  for (int64_t i = 0; i < num_axes; ++i) {
    for (int64_t j = i + 1; j < num_axes; ++j) {
      if (!axes[i].CanCoexistWithoutOverlap(axes[j])) {
        return false;
      }
    }
  }
  return true;
}

// Validates each entry of `axes` against `mesh`, returning the first error
// encountered. If all entries are individually valid, returns
// InvalidArgumentError when AxesCanCoexistWithoutOverlap(axes) is false and
// OkStatus otherwise.
absl::Status ValidateSpanOfAxes(absl::Span<const AxisRef> axes,
                                const Mesh& mesh) {
  for (const AxisRef& axis_ref : axes) {
    TF_RETURN_IF_ERROR(axis_ref.Validate(mesh));
  }
  const bool compatible = AxesCanCoexistWithoutOverlap(axes);
  return compatible ? absl::OkStatus()
                    : absl::InvalidArgumentError(
                          "Axes cannot coexist or axes overlap.");
}

}  // namespace xla
